{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parte 10: Aprendizaje Federado con Agregación Encriptada de Gradientes\n",
    "\n",
    "En las últimas secciones hemos aprendido sobre el cómputo encriptado construyendo varios programas simples. En esta sección, regresaremos al [Demo de Aprendizaje Federado de la parte 4](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/translations/espa%C3%B1ol/Parte%2004%20-%20Aprendizaje%20Federado%20via%20un%20Agregador%20Confiable.ipynb), donde teníamos un \"agregador confiable\" quien es el responsable de promediar las actualizaciones de los modelos de varios trabajadores.\n",
    "\n",
    "Ahora vamos a usar nuestras nuevas herramientas de cómputo encriptado para dispensar este agregador confiable ya que no es ideal tenerlo porque asume que podemos encontrar a alguien lo suficientemente confiable para que tenga acceso a esta información sensible. Esto no siempre es el caso.\n",
    "\n",
    "Por lo tanto, en este notebook mostraremos cómo podemos usar la computación segura multi-parte (CSMP) para realizar una agregación segura de tal manera que necesitemos un \"agregador seguro\".\n",
    "\n",
    "Autores:\n",
    "- Theo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "\n",
    "Traductores:\n",
    "- Arturo Márquez Flores - Twitter: [@arturomf94](https://twitter.com/arturomf94) \n",
    "- Ricardo Pretelt - Twitter: [@ricardopretelt](https://twitter.com/ricardopretelt)\n",
    "- Carlos Salgado - Github: [@socd06](https://github.com/socd06) \n",
    "- Daniel Firebanks-Quevedo - GitHub: [@thefirebanks](https://www.github.com/thefirebanks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sección 1: Aprendizaje Federado Normal\n",
    "\n",
    "Primero, aquí hay código que realiza un aprendizaje federado clásico en el conjunto de datos Boston Housing. Esta sección del código puede desglosarse en varias secciones.\n",
    "\n",
    "### Configuración"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "class Parser:\n",
    "    \"\"\"Parámetros para el entrenamiento\"\"\"\n",
    "    def __init__(self):\n",
    "        self.epochs = 10\n",
    "        self.lr = 0.001\n",
    "        self.test_batch_size = 8\n",
    "        self.batch_size = 8\n",
    "        self.log_interval = 10\n",
    "        self.seed = 1\n",
    "    \n",
    "args = Parser()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "kwargs = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cargar los Datos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/BostonHousing/boston_housing.pickle','rb') as f:\n",
    "    ((X, y), (X_test, y_test)) = pickle.load(f)\n",
    "\n",
    "X = torch.from_numpy(X).float()\n",
    "y = torch.from_numpy(y).float()\n",
    "X_test = torch.from_numpy(X_test).float()\n",
    "y_test = torch.from_numpy(y_test).float()\n",
    "# preprocesamiento\n",
    "mean = X.mean(0, keepdim=True)\n",
    "dev = X.std(0, keepdim=True)\n",
    "mean[:, 3] = 0. # la columna 3 es binaria\n",
    "dev[:, 3] = 1.  # así que no la estandarizamos\n",
    "X = (X - mean) / dev\n",
    "X_test = (X_test - mean) / dev\n",
    "train = TensorDataset(X, y)\n",
    "test = TensorDataset(X_test, y_test)\n",
    "train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estructura de la Red Neuronal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(13, 32)\n",
    "        self.fc2 = nn.Linear(32, 24)\n",
    "        self.fc3 = nn.Linear(24, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 13)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "model = Net()\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enganche de Pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "\n",
    "hook = sy.TorchHook(torch)\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "james = sy.VirtualWorker(hook, id=\"james\")\n",
    "\n",
    "compute_nodes = [bob, alice]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Mandar los datos a los trabajadores** <br>\n",
    "Usualmente ya lo tendrían, eso sólo es para el demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    train_distributed_dataset.append((data, target))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Función de Entrenamiento"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data,target) in enumerate(train_distributed_dataset):\n",
    "        worker = data.location\n",
    "        model.send(worker)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # actualiza el modelo\n",
    "        pred = model(data)\n",
    "        loss = F.mse_loss(pred.view(-1), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get()\n",
    "            \n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get()\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * data.shape[0], len(train_loader),\n",
    "                       100. * batch_idx / len(train_loader), loss.item()))\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Función para Pruebas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = model(data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # suma la pérdida\n",
    "        pred = output.data.max(1, keepdim=True)[1] # obtén el índice del máximo de la probabilidad logarítmica\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entrenando el Modelo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(epoch)\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculando el Desempeño"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sección 2: Añadiendo la Agregación Encriptada\n",
    "\n",
    "Ahora vamos a modificar este ejemplo sutilmente para agregar los gradientes de manera encriptada. La diferencia principal está en las líneas 1 o 2 del código en la función `train()`, que mostraremos. Por el momento, vamos a reprocesar nuestros datos e inicializar el modelo para bob y alice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "remote_dataset = (list(),list())\n",
    "\n",
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))\n",
    "\n",
    "def update(data, target, model, optimizer):\n",
    "    model.send(data.location)\n",
    "    optimizer.zero_grad()\n",
    "    pred = model(data)\n",
    "    loss = F.mse_loss(pred.view(-1), target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return model\n",
    "\n",
    "bobs_model = Net()\n",
    "alices_model = Net()\n",
    "\n",
    "bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)\n",
    "alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)\n",
    "\n",
    "models = [bobs_model, alices_model]\n",
    "params = [list(bobs_model.parameters()), list(alices_model.parameters())]\n",
    "optimizers = [bobs_optimizer, alices_optimizer]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construyendo la Lógica de Entrenamiento\n",
    "\n",
    "La única diferencia **real** está dentro del método de entrenamiento. Vamos a ver esto paso por paso.\n",
    "\n",
    "### Parte A: Entrenamiento:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# esto selecciona el batch que entrenaremos\n",
    "data_index = 0\n",
    "# actualiza los modelos remotos\n",
    "# podríamos iterar esto múltiples veces antes de proceder, pero sólo vamos a hacer una iteración por trabajador aquí\n",
    "for remote_index in range(len(compute_nodes)):\n",
    "    data, target = remote_dataset[remote_index][data_index]\n",
    "    models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parte B: Agregación Encriptada"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# crea una lista donde depositaremos nuestro modelo modelo promedio encriptado\n",
    "new_params = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# itera sobre cada parámetro\n",
    "for param_i in range(len(params[0])):\n",
    "\n",
    "    # para cada trabajador\n",
    "    spdz_params = list()\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        \n",
    "        # selecciona el parámetro idéntico para cada trabajador y copia\n",
    "        copy_of_parameter = params[remote_index][param_i].copy()\n",
    "        \n",
    "        # como la CSMP sólo puede trabajar con enteros (sin puntos flotantes), necesitamos\n",
    "        # utilizar enteros para guardar la información decimal. En otras palabras, necesitamos\n",
    "        # usar una codificación con precisión fija.\n",
    "        fixed_precision_param = copy_of_parameter.fix_precision()\n",
    "        \n",
    "        # ahora encriptamos esto en una máquina remota. Nota que\n",
    "        # fixed_precision_param ya es un puntero. Entonces, cuando\n",
    "        # llamamos share encripta los datos a los que se apunta. Esto\n",
    "        # regresa un puntero al objeto secreto compartido en el CMP,\n",
    "        # que necesitamos tomar.\n",
    "        encrypted_param = fixed_precision_param.share(bob, alice, crypto_provider=james)\n",
    "        \n",
    "        # ahora tomamos el puntero\n",
    "        param = encrypted_param.get()\n",
    "        \n",
    "        # guarda el parámetro para promediarlo con el mismo parámetro de\n",
    "        # los otros trabajadores\n",
    "        spdz_params.append(param)\n",
    "\n",
    "    # promedia params con múltiples trabajadores, tómalos a la máquina local\n",
    "    # desencripta y decodifica (de la precisión fija) al número de punto flotante\n",
    "    new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "    \n",
    "    # guarda en nuevo parámetro promediado\n",
    "    new_params.append(new_param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parte C: Limpieza"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    for model in params:\n",
    "        for param in model:\n",
    "            param *= 0\n",
    "\n",
    "    for model in models:\n",
    "        model.get()\n",
    "\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        for param_index in range(len(params[remote_index])):\n",
    "            params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ¡Ahora lo juntamos!\n",
    "\n",
    "Y ahora que conocemos cada paso, podemos juntarlo en un ciclo de entrenamiento."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    for data_index in range(len(remote_dataset[0])-1):\n",
    "        # actualiza los modelos remotos\n",
    "        for remote_index in range(len(compute_nodes)):\n",
    "            data, target = remote_dataset[remote_index][data_index]\n",
    "            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n",
    "\n",
    "        # agregación encriptada\n",
    "        new_params = list()\n",
    "        for param_i in range(len(params[0])):\n",
    "            spdz_params = list()\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=james).get())\n",
    "\n",
    "            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "            new_params.append(new_param)\n",
    "\n",
    "        # limpieza\n",
    "        with torch.no_grad():\n",
    "            for model in params:\n",
    "                for param in model:\n",
    "                    param *= 0\n",
    "\n",
    "            for model in models:\n",
    "                model.get()\n",
    "\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                for param_index in range(len(params[remote_index])):\n",
    "                    params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    models[0].eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = models[0](data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # suma la pérdida\n",
    "        pred = output.data.max(1, keepdim=True)[1] # obtén el índice del máximo de la probabilidad logarítmica\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('Test set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(args.epochs):\n",
    "    print(f\"Epoch {epoch + 1}\")\n",
    "    train(epoch)\n",
    "    test()\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# !Felicitaciones! - !Es hora de unirte a la comunidad!\n",
    "\n",
    "¡Felicitaciones por completar esta parte del tutorial! Si te gustó y quieres unirte al movimiento para preservar la privacidad, propiedad descentralizada de IA y la cadena de suministro de IA (datos), puedes hacerlo de las ¡siguientes formas!\n",
    "\n",
    "### Dale una estrella a PySyft en GitHub\n",
    "\n",
    "La forma más fácil de ayudar a nuestra comunidad es por darle estrellas a ¡los repositorios de Github! Esto ayuda a crear consciencia de las interesantes herramientas que estamos construyendo.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### ¡Únete a nuestro Slack!\n",
    "\n",
    "La mejor manera de mantenerte actualizado con los últimos avances es ¡unirte a la comunidad! Tú lo puedes hacer llenando el formulario en [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### ¡Únete a un proyecto de código!\n",
    "\n",
    "La mejor manera de contribuir a nuestra comunidad es convertirte en un ¡contribuidor de código! En cualquier momento puedes ir al _Github Issues_ de PySyft y filtrar por \"Proyectos\". Esto mostrará todos los tiquetes de nivel superior dando un resumen de los proyectos a los que ¡te puedes unir! Si no te quieres unir a un proyecto, pero quieres hacer un poco de código, también puedes mirar más mini-proyectos \"de una persona\" buscando por Github Issues con la etiqueta \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donar\n",
    "\n",
    "Si no tienes tiempo para contribuir a nuestra base de código, pero quieres ofrecer tu ayuda, también puedes aportar a nuestro *Open Collective\"*. Todas las donaciones van a nuestro *web hosting* y otros gastos de nuestra comunidad como ¡hackathons y meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
